"""
Chronos-2
---------

For detailed examples and tutorials, see:

* `Chronos-2 Foundation Model Examples
  <https://unit8co.github.io/darts/examples/25-Chronos-2-examples.html>`__
"""

import math
import os
from dataclasses import dataclass
from typing import Any, Literal, Optional, Union, cast

import torch
from torch import nn

from darts.logging import get_logger, raise_log
from darts.models.components.chronos2_submodels import (
    _Chronos2Encoder,
    _InstanceNorm,
    _Patch,
    _ResidualBlock,
)
from darts.models.components.huggingface_connector import (
    HuggingFaceConnector,
)
from darts.models.forecasting.foundation_model import (
    FoundationModel,
)
from darts.models.forecasting.pl_forecasting_module import (
    PLForecastingModule,
)
from darts.utils.data.torch_datasets.utils import PLModuleInput, TorchTrainingSample
from darts.utils.likelihood_models.torch import QuantileRegression

logger = get_logger(__name__)


@dataclass
class _Chronos2ForecastingConfig:
    context_length: int
    output_patch_size: int
    input_patch_size: int
    input_patch_stride: int
    quantiles: list[float]
    use_reg_token: bool = False
    use_arcsinh: bool = False
    max_output_patches: int = 1
    time_encoding_scale: int | None = None


class _Chronos2Module(PLForecastingModule):
    def __init__(
        self,
        d_model: int = 512,
        d_kv: int = 64,
        d_ff: int = 2048,
        num_layers: int = 6,
        num_heads: int = 8,
        dropout_rate: float = 0.1,
        layer_norm_epsilon: float = 1e-6,
        feed_forward_proj: str = "relu",
        rope_theta: float = 10000.0,
        attn_implementation: Literal["eager", "sdpa"] | None = None,
        chronos_config: Optional[dict[str, Any]] = None,
        **kwargs,
    ):
        """PyTorch module implementing the Chronos-2 model, ported from
        `amazon-science/chronos-forecasting <https://github.com/amazon-science/chronos-forecasting>`_ and
        adapted for Darts :class:`PLForecastingModule` interface.

        Parameters
        ----------
        d_model
            Dimension of the model embeddings, also called "model size" in Transformer.
        d_kv
            Dimension of the key and value projections in multi-head attention.
        d_ff
            Dimension of the feed-forward network hidden layer.
        num_layers
            Number of Chronos-2 encoder layers.
        num_heads
            Number of attention heads in each encoder block.
        dropout_rate
            Dropout rate of the model.
        layer_norm_epsilon
            Epsilon value for layer normalization layers.
        feed_forward_proj
            Activation of feed-forward network.
        rope_theta
            Base period for Rotary Position Embeddings (RoPE).
        attn_implementation
            Attention implementation to use. If None, defaults to "sdpa".
        chronos_config
            Configuration parameters for Chronos-2 model. See :class:`_Chronos2ForecastingConfig` for details.
        **kwargs
            all parameters required for :class:`darts.models.forecasting.pl_forecasting_module.PLForecastingModule`
            base class.
        """

        super().__init__(**kwargs)
        self.d_model = d_model
        self.d_kv = d_kv
        self.d_ff = d_ff
        self.num_layers = num_layers
        self.num_heads = num_heads
        self.dropout_rate = dropout_rate
        self.layer_norm_epsilon = layer_norm_epsilon
        self.feed_forward_proj = feed_forward_proj
        self.rope_theta = rope_theta

        act_info = self.feed_forward_proj.split("-")
        self.dense_act_fn = act_info[-1]
        self.is_gated_act = act_info[0] == "gated"

        if self.is_gated_act:
            raise_log(
                ValueError("gated activation is not supported"),
                logger,
            )

        # Attention implementation - default to "sdpa" if not specified
        self.attn_implementation = attn_implementation or "sdpa"
        if self.attn_implementation not in ["eager", "sdpa"]:
            raise_log(
                ValueError(
                    f"attn_implementation {self.attn_implementation} is not supported"
                ),
                logger,
            )

        # Chronos-2 forecasting specific config
        chronos_config = chronos_config or {}
        self.chronos_config = _Chronos2ForecastingConfig(**chronos_config)

        # Only decoder_start_id (and optionally REG token)
        if self.chronos_config.use_reg_token:
            self.reg_token_id = 1

        if (
            self.chronos_config.input_patch_size
            != self.chronos_config.output_patch_size
        ):
            raise_log(
                ValueError(
                    f"input_patch_size and output_patch_size sizes must be equal, "
                    f"but found {self.chronos_config.input_patch_size} and {self.chronos_config.output_patch_size}"
                ),
                logger,
            )

        self.vocab_size = 2 if self.chronos_config.use_reg_token else 1
        self.shared = nn.Embedding(self.vocab_size, self.d_model)

        # Input patch embedding layer
        self.input_patch_embedding = _ResidualBlock(
            # x3 for [time_embedding, patch, patch_mask]
            in_dim=self.chronos_config.input_patch_size * 3,
            h_dim=self.d_ff,
            out_dim=self.d_model,
            act_fn_name=self.dense_act_fn,
            dropout_p=self.dropout_rate,
        )

        # patching layer
        self.patch = _Patch(
            patch_size=self.chronos_config.input_patch_size,
            patch_stride=self.chronos_config.input_patch_stride,
        )

        # instance normalization, also referred to as "scaling" in Chronos and GluonTS
        self.instance_norm = _InstanceNorm(use_arcsinh=self.chronos_config.use_arcsinh)

        self.encoder = _Chronos2Encoder(
            d_model=self.d_model,
            d_kv=self.d_kv,
            d_ff=self.d_ff,
            num_heads=self.num_heads,
            dropout_rate=self.dropout_rate,
            rope_theta=self.rope_theta,
            attn_implementation=self.attn_implementation,
            dense_act_fn=self.dense_act_fn,
            layer_norm_epsilon=self.layer_norm_epsilon,
            is_gated_act=self.is_gated_act,
            num_layers=self.num_layers,
        )

        quantiles = self.chronos_config.quantiles
        self.num_quantiles = len(quantiles)
        quantiles_tensor = torch.tensor(quantiles)
        self.register_buffer("quantiles", quantiles_tensor, persistent=False)

        # gather indices of user-specified quantiles
        user_quantiles: list[float] = (
            self.likelihood.quantiles
            if isinstance(self.likelihood, QuantileRegression)
            else [0.5]
        )
        self.user_quantile_indices = [quantiles.index(q) for q in user_quantiles]

        self.output_patch_embedding = _ResidualBlock(
            in_dim=self.d_model,
            h_dim=self.d_ff,
            out_dim=self.num_quantiles * self.chronos_config.output_patch_size,
            act_fn_name=self.dense_act_fn,
            dropout_p=self.dropout_rate,
        )

    def _prepare_patched_context(
        self,
        context: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
        context_mask = torch.isnan(context).logical_not().to(context.dtype)

        batch_size, _ = context.shape

        # scaling
        context, loc_scale = self.instance_norm(context)

        # scaling is done in 32-bit precision, then the context is moved to model's dtype
        context = context.to(self.dtype)
        context_mask = context_mask.to(self.dtype)

        # patching
        patched_context = self.patch(context)
        patched_mask = torch.nan_to_num(self.patch(context_mask), nan=0.0)
        patched_context = torch.where(patched_mask > 0.0, patched_context, 0.0)

        # attention_mask = 1 if at least one item in the patch is observed
        attention_mask = patched_mask.sum(dim=-1) > 0  # (batch_size, num_patches)
        num_context_patches = attention_mask.shape[-1]

        # context time encoding: every observation is assigned a sequential time index,
        # scaled by model's context length = [-C, -(C-1), ..., -1] / context_length
        final_context_length = (
            num_context_patches * self.chronos_config.input_patch_size
        )
        context_time_enc = torch.arange(
            start=-final_context_length, end=0, device=self.device, dtype=torch.float32
        )
        context_time_enc = context_time_enc.div(
            cast(int, self.chronos_config.time_encoding_scale)
        ).to(self.dtype)
        context_time_enc = context_time_enc.view(
            1, num_context_patches, self.chronos_config.input_patch_size
        ).expand(
            batch_size,
            num_context_patches,
            self.chronos_config.input_patch_size,
        )

        # concat time encoding, context and mask along the last (feature) dim
        patched_context = torch.cat(
            [context_time_enc, patched_context, patched_mask], dim=-1
        )

        return patched_context, attention_mask, loc_scale

    def _prepare_patched_future(
        self,
        future_covariates: torch.Tensor,
        loc_scale: tuple[torch.Tensor, torch.Tensor],
        num_output_patches: int,
        batch_size: int,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        output_patch_size = self.chronos_config.output_patch_size
        future_covariates, _ = self.instance_norm(future_covariates, loc_scale)
        future_covariates = cast(torch.Tensor, future_covariates)
        future_covariates = future_covariates.to(self.dtype)

        future_covariates_mask = (
            torch.isnan(future_covariates).logical_not().to(future_covariates.dtype)
        )

        future_covariates = torch.where(
            future_covariates_mask > 0.0, future_covariates, 0.0
        )

        # add padding if the length of future_covariates is not an integer multiple of output_patch_size
        if num_output_patches * output_patch_size > future_covariates.shape[-1]:
            padding_shape = (
                *future_covariates.shape[:-1],
                num_output_patches * output_patch_size - future_covariates.shape[-1],
            )
            future_covariates = torch.cat(
                [
                    future_covariates,
                    torch.zeros(padding_shape).to(future_covariates),
                ],
                dim=-1,
            )
            future_covariates_mask = torch.cat(
                [
                    future_covariates_mask,
                    torch.zeros(padding_shape).to(future_covariates_mask),
                ],
                dim=-1,
            )

        patched_future_covariates = future_covariates.view(
            batch_size, num_output_patches, output_patch_size
        )
        patched_future_covariates_mask = future_covariates_mask.view(
            batch_size, num_output_patches, output_patch_size
        )

        # future time encoding: every future timestep is assigned a sequential time index,
        # scaled by model's context length = [0, 1, ..., h-1] / context_length
        final_future_length = num_output_patches * output_patch_size
        future_time_enc = torch.arange(
            start=0, end=final_future_length, device=self.device, dtype=torch.float32
        )
        future_time_enc = future_time_enc.div(
            cast(int, self.chronos_config.time_encoding_scale)
        ).to(self.dtype)
        future_time_enc = future_time_enc.view(
            1, num_output_patches, output_patch_size
        ).expand(
            batch_size,
            num_output_patches,
            output_patch_size,
        )

        patched_future = torch.cat(
            [
                future_time_enc,
                patched_future_covariates,
                patched_future_covariates_mask,
            ],
            dim=-1,
        )

        return patched_future, patched_future_covariates_mask

    def _forward(
        self,
        context: torch.Tensor,
        group_ids: torch.Tensor,
        future_covariates: torch.Tensor,
        num_output_patches: int = 1,
    ) -> torch.Tensor:
        """Original forward pass of the Chronos-2 model.

        Parameters
        ----------
        context
            Input tensor of shape (batch_size, context_length) containing the historical values
        group_ids : torch.Tensor | None, optional
            Group IDs of shape (batch_size,) indicating which times series in the batch form a group.
            A group indicates a task, for example, for a batch of size 6:
            - if groups_ids = [0, 1, 2, 3, 4, 5], each time series is treated independently.
            - if groups_ids = [0, 0, 1, 1, 1, 2], information is mixed across the first two time series (id=0),
                the next three time series (id=1) and the last time series is treated separately. Information is
                NOT shared among time series from different groups.
            The ordering and specific values of group_ids are not important, all time series with the same group
            ID form a group.
        future_covariates
            Tensor of shape (batch_size, future_length) containing future covariates. Note that the size of
            tensor along the first axis is equal to the batch_size. This means that future values (which may be NaNs)
            must be provided for each time series in the batch. For any time series that need to be forecasted, the
            future_covariates can be set to NaNs, if ``future_covariates_mask`` is omitted or to an arbitrary dummy
            value when ``future_covariates_mask`` is provided. ``future_covariates`` can be used with ``group_ids``
            to construct heterogenous forecasting tasks in a single batch. For example:
            - future_covariates = [[nan, ...], [nan, ...], [v1, ...], [v2, ...], [nan, ...], [nan, ...]]
            - groups_ids = [0, 0, 1, 1, 1, 2]
            - future_covariates_mask = None
            contains 3 types of forecasting tasks:
            - [0, 0]: The first task, both future_covariates are missing, which implies that the two time series need to
                be forecasted jointly, i.e., multivariate forecasting.
            - [1, 1, 1]: In the next task, the first two future_covariates are available and the last one is missing
                ([v1, ...], [v2, ...], [nan, ...]), where [v1, ...] and [v1, ...] denote an arbitrary sequence of
                values. This indicates that the first two time series are known covariates and the third one needs to be
                forecasted by the model.
            - [2]: The last task has a single time series in the group which needs to be forecasted independently.
            There is no theoretical limit on the number of time series in a group, i.e., the number of targets and known
            covariates in a task. The above setup subsumes tasks with past-only covariates as the model's prediction for
            those time series can simply be ignored downstream.
        num_output_patches
            Number of output patches to generate predictions for, by default 1
            When ``future_covariates`` and/or ``future_target`` are provided, num_output_patches should be large enough
            to accommodate their lengths, i.e., num_output_patches * output_patch_size >= future_length

        Returns
        -------
        torch.Tensor
            Quantile predictions of shape `(batch_size, n_variables * n_output_patches * n_quantiles * patch_size)`.
            quantile_preds will contain an entry for every time series in the context batch regardless of whether it
            was a known future covariate.
        """

        batch_size = context.shape[0]
        patched_context, attention_mask, loc_scale = self._prepare_patched_context(
            context=context
        )
        num_context_patches = attention_mask.shape[-1]

        # get input embeddings of shape (batch, num_context_patches, d_model)
        input_embeds: torch.Tensor = self.input_patch_embedding(patched_context)
        # append [REG] special token embedding, if needed
        if self.chronos_config.use_reg_token:
            reg_input_ids = torch.full(
                (batch_size, 1), self.reg_token_id, device=input_embeds.device
            )
            reg_embeds = self.shared(reg_input_ids)
            input_embeds = torch.cat([input_embeds, reg_embeds], dim=-2)
            attention_mask = torch.cat(
                [
                    attention_mask.to(self.dtype),
                    torch.ones_like(reg_input_ids).to(self.dtype),
                ],
                dim=-1,
            )

        patched_future, _ = self._prepare_patched_future(
            future_covariates=future_covariates,
            loc_scale=loc_scale,
            num_output_patches=num_output_patches,
            batch_size=batch_size,
        )
        future_attention_mask = torch.ones(
            batch_size,
            num_output_patches,
            dtype=attention_mask.dtype,
            device=self.device,
        )

        # get future embeddings of shape (batch, num_output_patches, d_model)
        future_embeds: torch.Tensor = self.input_patch_embedding(patched_future)

        # concatenate context and future embeddings and masks
        input_embeds = torch.cat([input_embeds, future_embeds], dim=-2)
        attention_mask = torch.cat([attention_mask, future_attention_mask], dim=-1)

        hidden_states: torch.Tensor = self.encoder(
            attention_mask=attention_mask,
            inputs_embeds=input_embeds,
            group_ids=group_ids,
        )

        assert hidden_states.shape == (
            batch_size,
            num_context_patches + 1 + num_output_patches,
            self.d_model,
        )

        # slice the last num_output_patches hidden states to be input into the output_patch_embedding
        forecast_embeds = hidden_states[:, -num_output_patches:]
        quantile_preds: torch.Tensor = self.output_patch_embedding(forecast_embeds)

        quantile_preds = quantile_preds.view(batch_size, -1)
        quantile_preds = self.instance_norm.inverse(quantile_preds, loc_scale)

        return quantile_preds

    # TODO: fine-tuning support w/ normalized loss
    # Currently, Darts own `RINorm` is not used as Chronos-2 has its own implementation. Major differences
    # 1. Chronos-2 `RINorm` normalizes both target and covariates, while Darts normalizes target only.
    # 2. Chronos-2 `RINorm` additionally applies `arcsinh` transformation after standardization
    # 3. Chronos-2 uses normalized values for loss computation, while Darts uses denormalized values.
    # We need to think about how best to implement Chronos-2 `RINorm` in `io_processor()` without
    # breaking existing behavior, while also allowing fine-tuning with normalized loss.
    def forward(self, x_in: PLModuleInput, *args, **kwargs) -> Any:
        """Chronos-2 model forward pass.

        Parameters
        ----------
        x_in
            comes as tuple `(x_past, x_future, x_static)` where `x_past` is the input/past chunk and `x_future`
            is the output/future chunk. Input dimensions are `(n_samples, n_time_steps, n_variables)`

        Returns
        -------
        torch.Tensor
            the output tensor in the shape of `(n_samples, n_time_steps, n_targets, n_quantiles)` for
            probabilistic forecasts, or `(n_samples, n_time_steps, n_targets, 1)` for
            deterministic forecasts (median only).
        """
        x_past, x_future, _ = x_in
        # x_past is a stack of [past_target, past_covariates, historic_future_covariates],
        # x_future is just future_covariates.
        # So here we need to create `future_covariates` in Chronos2's format that is
        # a stack of [past_target (NaNs), past_covariates (NaNs), future_covariates].
        batch_size, past_length, n_variables = x_past.shape
        output_chunk_length = self.output_chunk_length or 0
        output_chunk_shift = self.output_chunk_shift
        future_length = output_chunk_shift + output_chunk_length
        future_covariates = torch.full(
            (batch_size, future_length, n_variables),
            torch.nan,
            device=x_past.device,
        )
        if x_future is not None:
            n_future_covs = x_future.shape[-1]
            future_covariates[:, -output_chunk_length:, -n_future_covs:] = x_future

        # reshape x_past and future_covariates to (batch * vars, time)
        context = x_past.permute(0, 2, 1).reshape(-1, past_length)
        future_covariates = future_covariates.permute(0, 2, 1).reshape(
            -1, future_length
        )

        # create group_ids according to sample index within the batch
        group_ids = torch.arange(batch_size, device=context.device).repeat_interleave(
            n_variables
        )

        # determine minimum number of patches to cover future_length
        num_output_patches = math.ceil(
            future_length / self.chronos_config.output_patch_size
        )

        # call original Chronos-2 forward pass
        # Unlike the original, we remove `context_mask`, `future_covariates_mask`, `future_target`,
        # `future_target_mask`, and `output_attentions` parameters. They are not needed for Darts'
        # implementation.
        # We also remove `einops` rearrange operation at the end so the raw output tensor is returned,
        # in shape of `(batch, vars * patches * quantiles * patch_size)`
        quantile_preds = self._forward(
            context=context,
            group_ids=group_ids,
            future_covariates=future_covariates,
            num_output_patches=num_output_patches,
        )

        # The permutation and reshaping operations below replace the `einops` rearrange
        # operations in the original Chronos-2 code to return the output tensor in Darts'
        # expected shape.
        # reshape quantile_preds to (batch, vars, patches, quantiles, patch_size)
        quantile_preds = quantile_preds.view(
            batch_size,
            n_variables,
            num_output_patches,
            self.num_quantiles,
            self.chronos_config.output_patch_size,
        )
        # permute and reshape to (batch, time, vars, quantiles)
        quantile_preds = quantile_preds.permute(0, 2, 4, 1, 3).reshape(
            batch_size,
            num_output_patches * self.chronos_config.output_patch_size,
            n_variables,
            self.num_quantiles,
        )

        # truncate to output_chunk_length
        quantile_preds = quantile_preds[:, output_chunk_shift:future_length, :, :]

        # select only target variables
        quantile_preds = quantile_preds[:, :, : self.n_targets, :]

        # select only user-specified quantiles or median if deterministic
        quantile_preds = quantile_preds[:, :, :, self.user_quantile_indices]

        return quantile_preds


class Chronos2Model(FoundationModel):
    # Fine-tuning is turned off for now pending proper fine-tuning support
    # and configuration.
    _allows_finetuning = False

    def __init__(
        self,
        input_chunk_length: int,
        output_chunk_length: int,
        output_chunk_shift: int = 0,
        likelihood: Optional[QuantileRegression] = None,
        hub_model_name: str = "amazon/chronos-2",
        hub_model_revision: Optional[str] = None,
        local_dir: Optional[Union[str, os.PathLike]] = None,
        **kwargs,
    ):
        """Chronos-2 Model for zero-shot forecasting.

        This is an implementation of Amazon's Chronos-2 model [1]_, [2]_, ported from
        `amazon-science/chronos-forecasting <https://github.com/amazon-science/chronos-forecasting>`_
        with adaptations to use the Darts API. From the original authors:

        "Chronos-2 is a 120M-parameter, encoder-only time series foundation model for zero-shot forecasting. It supports
        univariate, multivariate, and covariate-informed tasks within a single architecture. Inspired by the T5 encoder,
        Chronos-2 produces multi-step-ahead quantile forecasts and uses a group attention mechanism for efficient
        in-context learning across related series and covariates. Trained on a combination of real-world and large-scale
        synthetic datasets, it achieves state-of-the-art zero-shot accuracy among public models on fev-bench, GIFT-Eval,
        and Chronos Benchmark II. Chronos-2 is also highly efficient, delivering over 300 time series forecasts per
        second on a single A10G GPU and supporting both GPU and CPU inference."

        This model supports past covariates (known for `input_chunk_length` points before prediction time),
        and future covariates (known for `output_chunk_length` points after prediction time).

        By default, using this model will automatically download and cache the pre-trained model from HuggingFace Hub
        (amazon/chronos-2). Alternatively, you can specify a local directory containing the model config and weights
        using the ``local_dir`` parameter.

        Two other variants of Chronos-2 are available on HuggingFace Hub:

        - `autogluon/chronos-2-small <https://huggingface.co/autogluon/chronos-2-small>`_: a smaller 28M parameter
          Chronos-2 model.
        - `autogluon/chronos-2-synth <https://huggingface.co/autogluon/chronos-2-synth>`_: a 120M parameter
          Chronos-2 model trained on synthetic data only.

        To use either of those variants, specify the ``hub_model_name`` parameter to the desired model ID.

        By default, this model is deterministic and outputs only the median (0.5 quantile). To enable probabilistic
        forecasts, pass a :class:`~darts.utils.likelihood_models.torch.QuantileRegression` instance to the
        ``likelihood`` parameter. The quantiles used must be a subset of those used during Chronos-2 pre-training, see
        below for details. It is recommended to call :func:`predict()` with ``predict_likelihood_parameters=True``
        or ``num_samples >> 1`` to get meaningful results.

        Parameters
        ----------
        input_chunk_length
            Number of time steps in the past to take as a model input (per chunk). Applies to the target
            series, and past and/or future covariates (if the model supports it).
            Maximum is 8192 for Chronos-2.
        output_chunk_length
            Number of time steps predicted at once (per chunk) by the internal model. Also, the number of future values
            from future covariates to use as a model input (if the model supports future covariates). It is not the same
            as forecast horizon `n` used in `predict()`, which is the desired number of prediction points generated
            using either a one-shot- or autoregressive forecast. Setting `n <= output_chunk_length` prevents
            auto-regression. This is useful when the covariates don't extend far enough into the future, or to prohibit
            the model from using future values of past and / or future covariates for prediction (depending on the
            model's covariate support).
            For Chronos-2, `output_chunk_length + output_chunk_shift` must be less than or equal to 1024.
        output_chunk_shift
            Optionally, the number of steps to shift the start of the output chunk into the future (relative to the
            input chunk end). This will create a gap between the input and output. If the model supports
            `future_covariates`, the future values are extracted from the shifted output chunk. Predictions will start
            `output_chunk_shift` steps after the end of the target `series`. If `output_chunk_shift` is set, the model
            cannot generate autoregressive predictions (`n > output_chunk_length`).
        likelihood
            The likelihood model to be used for probabilistic forecasts. Must be ``None`` or an instance of
            :class:`~darts.utils.likelihood_models.torch.QuantileRegression`. If using ``QuantileRegression``,
            the quantiles must be a subset of those used during Chronos-2 pre-training:
            [0.01, 0.05, 0.1, 0.15, 0.2, 0.25, 0.3, 0.35, 0.4, 0.45, 0.5, 0.55, 0.6, 0.65, 0.7, 0.75, 0.8, 0.85, 0.9,
            0.95, 0.99].
            Default: ``None``, which will make Chronos-2 deterministic (median quantile only).
        hub_model_name
            The model ID on HuggingFace Hub. Default: ``"amazon/chronos-2"``. Other available variants include
            ``"autogluon/chronos-2-small"`` and ``"autogluon/chronos-2-synth"``.
        hub_model_revision
            The model version to use. This can be a branch name, tag name, or commit hash. Default is ``None``, which
            will use the default branch from ``hub_model_name``.
        local_dir
            Optional local directory to load the pre-downloaded model. If specified and the directory is empty, the
            model will be downloaded from HuggingFace Hub and saved to this directory. Default is ``None``, which will
            use a cache directory managed by ``huggingface_hub`` instead. Note that this is different from the
            ``work_dir`` parameter used for saving model checkpoints during fine-tuning.
        **kwargs
            Optional arguments to initialize the pytorch_lightning.Module, pytorch_lightning.Trainer, and
            Darts' :class:`TorchForecastingModel`.

        loss_fn
            PyTorch loss function used for fine-tuning a deterministic Chronos-2 model. Ignored for probabilistic
            Chronos-2 when ``likelihood`` is specified. Default: ``nn.MSELoss()``.
        torch_metrics
            A torch metric or a ``MetricCollection`` used for evaluation. A full list of available metrics can be found
            at https://torchmetrics.readthedocs.io/en/latest/. Default: ``None``.
        optimizer_cls
            The PyTorch optimizer class to be used. Default: ``torch.optim.Adam``.
        optimizer_kwargs
            Optionally, some keyword arguments for the PyTorch optimizer (e.g., ``{'lr': 1e-3}``
            for specifying a learning rate). Otherwise, the default values of the selected ``optimizer_cls``
            will be used. Default: ``None``.
        lr_scheduler_cls
            Optionally, the PyTorch learning rate scheduler class to be used. Specifying ``None`` corresponds
            to using a constant learning rate. Default: ``None``.
        lr_scheduler_kwargs
            Optionally, some keyword arguments for the PyTorch learning rate scheduler. Default: ``None``.
        use_reversible_instance_norm
            Whether to use reversible instance normalization `RINorm` against distribution shift. Ignored by
            Chronos-2 as it has its own `RINorm` implementation.
        batch_size
            Number of time series (input and output sequences) used in each training pass. Default: ``32``.
        n_epochs
            Number of epochs over which to train the model. Default: ``100``.
        model_name
            Name of the model. Used for creating checkpoints and saving tensorboard data. If not specified,
            defaults to the following string ``"YYYY-mm-dd_HH_MM_SS_torch_model_run_PID"``, where the initial part
            of the name is formatted with the local date and time, while PID is the processed ID (preventing models
            spawned at the same time by different processes to share the same model_name). E.g.,
            ``"2021-06-14_09_53_32_torch_model_run_44607"``.
        work_dir
            Path of the working directory, where to save checkpoints and Tensorboard summaries.
            Default: current working directory.
        log_tensorboard
            If set, use Tensorboard to log the different parameters. The logs will be located in:
            ``"{work_dir}/darts_logs/{model_name}/logs/"``. Default: ``False``.
        nr_epochs_val_period
            Number of epochs to wait before evaluating the validation loss (if a validation
            ``TimeSeries`` is passed to the :func:`fit()` method). Default: ``1``.
        force_reset
            If set to ``True``, any previously-existing model with the same name will be reset (all checkpoints will
            be discarded). Default: ``False``.
        save_checkpoints
            Whether to automatically save the untrained model and checkpoints from training.
            To load the model from checkpoint, call :func:`MyModelClass.load_from_checkpoint()`, where
            :class:`MyModelClass` is the :class:`TorchForecastingModel` class that was used (such as :class:`TFTModel`,
            :class:`NBEATSModel`, etc.). If set to ``False``, the model can still be manually saved using
            :func:`save()` and loaded using :func:`load()`. Default: ``False``.
        add_encoders
            A large number of past and future covariates can be automatically generated with `add_encoders`.
            This can be done by adding multiple pre-defined index encoders and/or custom user-made functions that
            will be used as index encoders. Additionally, a transformer such as Darts' :class:`Scaler` can be added to
            transform the generated covariates. This happens all under one hood and only needs to be specified at
            model creation.
            Read :meth:`SequentialEncoder <darts.dataprocessing.encoders.SequentialEncoder>` to find out more about
            ``add_encoders``. Default: ``None``. An example showing some of ``add_encoders`` features:

            .. highlight:: python
            .. code-block:: python

                def encode_year(idx):
                    return (idx.year - 1950) / 50

                add_encoders={
                    'cyclic': {'future': ['month']},
                    'datetime_attribute': {'future': ['hour', 'dayofweek']},
                    'position': {'past': ['relative'], 'future': ['relative']},
                    'custom': {'past': [encode_year]},
                    'transformer': Scaler(),
                    'tz': 'CET'
                }
            ..
        random_state
            Controls the randomness of the weights initialization and reproducible forecasting.
        pl_trainer_kwargs
            By default :class:`TorchForecastingModel` creates a PyTorch Lightning Trainer with several useful presets
            that performs the training, validation and prediction processes. These presets include automatic
            checkpointing, tensorboard logging, setting the torch device and more.
            With ``pl_trainer_kwargs`` you can add additional kwargs to instantiate the PyTorch Lightning trainer
            object. Check the `PL Trainer documentation
            <https://pytorch-lightning.readthedocs.io/en/stable/common/trainer.html>`__ for more information about the
            supported kwargs. Default: ``None``.
            Running on GPU(s) is also possible using ``pl_trainer_kwargs`` by specifying keys ``"accelerator",
            "devices", and "auto_select_gpus"``. Some examples for setting the devices inside the ``pl_trainer_kwargs``
            dict:

            - ``{"accelerator": "cpu"}`` for CPU,
            - ``{"accelerator": "gpu", "devices": [i]}`` to use only GPU ``i`` (``i`` must be an integer),
            - ``{"accelerator": "gpu", "devices": -1, "auto_select_gpus": True}`` to use all available GPUS.

            For more info, see here:
            https://pytorch-lightning.readthedocs.io/en/stable/common/trainer.html#trainer-flags , and
            https://pytorch-lightning.readthedocs.io/en/stable/accelerators/gpu_basic.html#train-on-multiple-gpus

            With parameter ``"callbacks"`` you can add custom or PyTorch-Lightning built-in callbacks to Darts'
            :class:`TorchForecastingModel`. Below is an example for adding EarlyStopping to the training process.
            The model will stop training early if the validation loss `val_loss` does not improve beyond
            specifications. For more information on callbacks, visit:
            `PyTorch Lightning Callbacks
            <https://pytorch-lightning.readthedocs.io/en/stable/extensions/callbacks.html>`__

            .. highlight:: python
            .. code-block:: python

                from pytorch_lightning.callbacks.early_stopping import EarlyStopping

                # stop training when validation loss does not decrease more than 0.05 (`min_delta`) over
                # a period of 5 epochs (`patience`)
                my_stopper = EarlyStopping(
                    monitor="val_loss",
                    patience=5,
                    min_delta=0.05,
                    mode='min',
                )

                pl_trainer_kwargs={"callbacks": [my_stopper]}
            ..

            Note that you can also use a custom PyTorch Lightning Trainer for training and prediction with optional
            parameter ``trainer`` in :func:`fit()` and :func:`predict()`.
        show_warnings
            whether to show warnings raised from PyTorch Lightning. Useful to detect potential issues of
            your forecasting use case. Default: ``False``.

        References
        ----------
        .. [1] A. Ansari, O. Shchur, J. Küken et al. "Chronos-2: From Univariate to Universal Forecasting", 2025.
                arXiv https://arxiv.org/abs/2510.15821.
        .. [2] "Introducing Chronos-2: From univariate to universal forecasting", 2025. Amazon Science Blog.
                https://www.amazon.science/blog/introducing-chronos-2-from-univariate-to-universal-forecasting

        Examples
        --------
        >>> from darts.datasets import WeatherDataset
        >>> from darts.models import Chronos2Model
        >>> # load data in float32 format (macOS issues with float64 and PyTorch)
        >>> series = WeatherDataset().load().astype("float32")
        >>> # predicting atmospheric pressure
        >>> target = series['p (mbar)'][:100]
        >>> # optionally, use past observed rainfall (pretending to be unknown beyond index 100)
        >>> past_cov = series['rain (mm)'][:100]
        >>> # optionally, use future temperatures (pretending this component is a forecast)
        >>> future_cov = series['T (degC)'][:106]
        >>> # by default, Chronos2Model is deterministic; to enable probabilistic forecasts,
        >>> # set likelihood to QuantileRegression and use a subset of the pre-trained quantiles
        >>> model = Chronos2Model(
        >>>     input_chunk_length=6,
        >>>     output_chunk_length=6,
        >>> )
        >>> # calling fit is still mandatory to ensure consistent number of components; however,
        >>> # Chronos2Model is training-free and the model weights are not updated
        >>> model.fit(target, past_covariates=past_cov, future_covariates=future_cov)
        >>> # when Chronos2Model is probabilistic, set ``predict_likelihood_parameters=True``
        >>> # or ``num_samples>>1`` to get meaningful results
        >>> pred = model.predict(6)
        >>> print(pred.all_values())
        [[[1005.7576 ]]
        [[1005.7418 ]]
        [[1005.7186 ]]
        [[1005.7074 ]]
        [[1005.6928 ]]
        [[1005.69617]]]

        .. note::
            Fine-tuning of Chronos-2 is not supported at the moment.
        .. note::
            Chronos-2 is licensed under the `Apache-2.0 License <https://github.com/amazon-science/chronos-forecasting/blob/main/LICENSE>`_,
            copyright Amazon.com, Inc. or its affiliates. By using this model, you agree to the terms and conditions of
            the license.
        .. warning::
            Due to differences in probabilistic sampling methods, zero-shot forecasts obtained here would differ from
            those obtained using the original implementation when prediction horizon `n` is larger than 1024.
        """
        hf_connector = HuggingFaceConnector(
            model_name=hub_model_name,
            model_revision=hub_model_revision,
            local_dir=local_dir,
        )

        # load model config for validation
        config = hf_connector.load_config()
        chronos_config = config["chronos_config"]

        # validate `input_chunk_length` against model's context_length
        context_length = chronos_config["context_length"]
        if input_chunk_length > context_length:
            raise_log(
                ValueError(
                    f"`input_chunk_length` {input_chunk_length} cannot be greater than "
                    f"model's context_length {context_length}"
                ),
                logger,
            )

        # validate `output_chunk_length` and `output_chunk_shift` against model's prediction length
        prediction_length = (
            chronos_config["output_patch_size"] * chronos_config["max_output_patches"]
        )
        if output_chunk_length + output_chunk_shift > prediction_length:
            raise_log(
                ValueError(
                    f"`output_chunk_length` {output_chunk_length} plus `output_chunk_shift` {output_chunk_shift} "
                    f"cannot be greater than model's maximum prediction length {prediction_length}"
                ),
                logger,
            )

        quantiles = chronos_config["quantiles"]
        # by default (`likelihood=None`), model is deterministic
        # otherwise, only QuantileRegression likelihood is supported and quantiles must be
        # a subset of Chronos-2 quantiles
        if likelihood is not None:
            if not isinstance(likelihood, QuantileRegression):
                raise_log(
                    ValueError(
                        f"Only QuantileRegression likelihood is supported for Chronos2Model in Darts. "
                        f"Got {type(likelihood)}."
                    ),
                    logger,
                )
            user_quantiles: list[float] = likelihood.quantiles
            if not set(user_quantiles).issubset(quantiles):
                raise_log(
                    ValueError(
                        f"The quantiles for QuantileRegression likelihood {user_quantiles} "
                        f"must be a subset of Chronos-2 quantiles {quantiles}."
                    ),
                    logger,
                )

        self.hf_connector = hf_connector
        super().__init__(enable_finetuning=False, **kwargs)

    def _create_model(self, train_sample: TorchTrainingSample) -> PLForecastingModule:
        pl_module_params = self.pl_module_params or {}
        return self.hf_connector.load_model(
            module_class=_Chronos2Module,
            pl_module_params=pl_module_params,
        )
